#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest

import torch

from examples.speech_recognition.data import data_utils


class DataUtilsTest(unittest.TestCase):

    def test_normalization(self):
        sample_len1 = torch.tensor([[-0.7661, -1.3889, -2.0972, -0.9134, -0.7071, -0.9765, -0.8700, -0.8283,
                                    0.7512,  1.3211,  2.1532,  2.1174,  1.2800,  1.2633,  1.6147,  1.6322,
                                    2.0723,  3.1522,  3.2852,  2.2309,  2.5569,  2.2183,  2.2862,  1.5886,
                                    0.8773,  0.8725,  1.2662,  0.9899,  1.1069,  1.3926,  1.2795,  1.1199,
                                    1.1477,  1.2687,  1.3843,  1.1903,  0.8355,  1.1367,  1.2639,  1.4707]])
        out = data_utils.apply_mv_norm(sample_len1)
        assert not torch.isnan(out).any()
        assert (out == sample_len1).all()
